This is an example data analytics case study and was presented in the lecture on 12.12.2019.
In the Analytics Cup, you will receive a few files, which are described below.
This file is an R Markdown (.Rmd) notebook used to make presentations and reports containing R code, output generated by that code, and text. In RStudio, you can use the Knit function (button in the top toolbar) to compile it to an .html or .pdf report.
You should use this file to understand what a typical analytics workflow looks like. During the AC, you may want to use this file as a reference to get started or to look up useful functions and packages. In particular, we demonstrate:

* how to make some exploratory plots using ggplot2
* how the meta-machine-learning package mlr can be used to apply many different R machine-learning packages with a common interface
Note: In the Analytics Cup, you will have to submit the script that generated your submission. This should be a regular R script, not an R notebook.
This notebook only scratches the surface of what the mlr package has to offer; you are advised to read the many tutorials found at https://mlr.mlr-org.com/. First, we will describe the challenge for the case study and the files you would be provided in this case. After that, we will go through the entire workflow once, highlighting useful features of different packages that you may want to use.
The dataset in this case study is based on the Global Power Plant Database commissioned by the World Resources Institute: http://datasets.wri.org/dataset/globalpowerplantdatabase
Citation:
Global Energy Observatory, Google, KTH Royal Institute of Technology in Stockholm, Enipedia, World Resources Institute. 2018. Global Power Plant Database. Published on Resource Watch and Google Earth Engine; http://resourcewatch.org/ https://earthengine.google.com/
GitHub: https://github.com/wri/global-power-plant-database
The files for this case study:

* README.md - a file that clearly outlines your task
* global_power_plants.csv - the data
* data_info.txt - some background information about the data
* submission_template.csv - the test set and/or a template that specifies the format of a valid submission

The notebook files (.Rmd and .html) would not be given in the actual AC but are provided in this case study to give you some more information.

Let's start by loading required packages.
library(tidyverse)
## ── Attaching packages ────── tidyverse 1.2.1 ──
## ✔ ggplot2 3.2.1 ✔ purrr 0.3.2
## ✔ tibble 2.1.3 ✔ dplyr 0.8.3
## ✔ tidyr 1.0.0 ✔ stringr 1.4.0
## ✔ readr 1.3.1 ✔ forcats 0.4.0
## ── Conflicts ───────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(lubridate)
##
## Attaching package: 'lubridate'
## The following object is masked from 'package:base':
##
## date
library(summarytools)
## Registered S3 method overwritten by 'pryr':
## method from
## print.bytes Rcpp
## system has no X11 capabilities, therefore only ascii graphs will be produced by dfSummary()
##
## Attaching package: 'summarytools'
## The following object is masked from 'package:tibble':
##
## view
library(ggmap)
## Google's Terms of Service: https://cloud.google.com/maps-platform/terms/.
## Please cite ggmap if you use it! See citation("ggmap") for details.
library(mlr) # you might need to install more packages depending on what learner you use in mlr!
## Loading required package: ParamHelpers
options(dplyr.width = Inf) # show all columns when printing to console
theme_set(theme_minimal()) # set ggplot theme for cleaner plotting
set.seed(2019) # in the AC, you'll be required to set a fixed random seed to make your work reproducible
read_csv('global_power_plants.csv')
## Parsed with column specification:
## cols(
## .default = col_character(),
## capacity_mw = col_double(),
## latitude = col_double(),
## longitude = col_double(),
## other_fuel3 = col_logical(),
## commissioning_year = col_double(),
## wepp_id = col_double(),
## year_of_capacity_data = col_double(),
## generation_gwh_2013 = col_double(),
## generation_gwh_2014 = col_double(),
## generation_gwh_2015 = col_double(),
## generation_gwh_2016 = col_double(),
## generation_gwh_2017 = col_double()
## )
## See spec(...) for full column specifications.
## Warning: 194 parsing failures.
## row col expected actual file
## 5503 wepp_id no trailing characters |1082331 'global_power_plants.csv'
## 12015 wepp_id no trailing characters |1085384 'global_power_plants.csv'
## 12042 wepp_id no trailing characters |1064452 'global_power_plants.csv'
## 12340 wepp_id no trailing characters |1060668 'global_power_plants.csv'
## 12453 wepp_id no trailing characters | 1074143| 1030932 'global_power_plants.csv'
## ..... ....... ...................... .................. .........................
## See problems(...) for more details.
Note: If you're reading the html report, you will only see the first 10k rows; there are actually 29,910 rows in the dataset.
There have been problems reading the data (see the warnings), so let's check why:
read_csv('global_power_plants.csv') %>% problems
## Parsed with column specification:
## cols(
## .default = col_character(),
## capacity_mw = col_double(),
## latitude = col_double(),
## longitude = col_double(),
## other_fuel3 = col_logical(),
## commissioning_year = col_double(),
## wepp_id = col_double(),
## year_of_capacity_data = col_double(),
## generation_gwh_2013 = col_double(),
## generation_gwh_2014 = col_double(),
## generation_gwh_2015 = col_double(),
## generation_gwh_2016 = col_double(),
## generation_gwh_2017 = col_double()
## )
## See spec(...) for full column specifications.
## Warning: 194 parsing failures.
## row col expected actual file
## 5503 wepp_id no trailing characters |1082331 'global_power_plants.csv'
## 12015 wepp_id no trailing characters |1085384 'global_power_plants.csv'
## 12042 wepp_id no trailing characters |1064452 'global_power_plants.csv'
## 12340 wepp_id no trailing characters |1060668 'global_power_plants.csv'
## 12453 wepp_id no trailing characters | 1074143| 1030932 'global_power_plants.csv'
## ..... ....... ...................... .................. .........................
## See problems(...) for more details.
Looks like there have been parsing failures for the columns wepp_id and other_fuel3. Since we care about the latter, we will fix the types in the column specification:
df <- read_csv('global_power_plants.csv',
col_types = cols(other_fuel3 = col_character(), wepp_id = col_character()))
df
Now the data looks like it has been read in successfully.
The summarytools package contains a useful tool to quickly give you a report of your data:
Note: The output won’t be properly displayed if you’re viewing the .html-report, run the source code yourself to open the summary-report in a browser.
library(summarytools)
df %>% dfSummary %>% view()
## Switching method to 'browser'
## Output file written: /tmp/Rtmp2Ute1H/file325f6b7790eb.html
There’s some columns we definitely won’t need in this analysis (see data_info.txt), let’s drop them now.
df <- df %>% select(-source, -url, -geolocation_source, -wepp_id)
df %>% sample_n(10)
Let’s fix the data in commissioning_year - the column contains decimal numbers, not just full years.
df <- df %>%
mutate(
commissioning_year = as_date(date_decimal(commissioning_year))
) %>%
rename(
id = gppd_idnr,
commission_date = commissioning_year
)
df
df %>%
# only use countries with more than 80 entries
group_by(country) %>% filter(n() > 80) %>% ungroup() %>%
ggplot( aes(x=country)) +
geom_bar() +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
df %>% ggplot(aes(x=primary_fuel, fill=primary_fuel)) +
geom_bar()+
theme(axis.text.x = element_text(angle = 45, hjust = 1))
df %>% ggplot(aes(y=capacity_mw, x= primary_fuel, fill = primary_fuel)) +
geom_boxplot() +
scale_y_log10() +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
countries_by_capacity <- df %>% group_by(country) %>% summarize(cap = sum(capacity_mw))
top_10_countries <- countries_by_capacity %>% top_n(10, cap) %>% pull(country)
df %>% filter(country %in% top_10_countries) %>% ggplot(aes(x=country, y=capacity_mw, fill=primary_fuel)) +
geom_bar(stat='identity') +
scale_fill_viridis_d()
df %>% filter(country %in% top_10_countries) %>% ggplot(aes(x=country, y=capacity_mw, fill=primary_fuel)) +
geom_bar(stat='identity', position='fill') +
scale_y_continuous(labels= scales::percent_format())
For all nuclear power plants in the U.S., make a line-chart of the development of annual power generation.
Hint: Use tidyr::pivot_longer for data preparation to get separate variables ‘observation_year’ and ‘observation_value’ per power station, then ggplot with geom_line using aes(..., group=name).
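Following the hint, the preparation and plot could be sketched like this (untested; assumes the country code 'USA' and the generation_gwh_* column names from above):

```r
df %>%
  filter(country == 'USA', primary_fuel == 'Nuclear') %>%
  # reshape the yearly generation columns into long format
  pivot_longer(
    cols = starts_with('generation_gwh_'),
    names_to = 'observation_year',
    names_prefix = 'generation_gwh_',
    values_to = 'observation_value'
  ) %>%
  # one line per power station across the observation years
  ggplot(aes(x = observation_year, y = observation_value, group = name)) +
  geom_line()
```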
plot <- df %>%
# disregard missing values when making the plot scales
filter(!is.na(generation_gwh_2017)) %>%
ggplot(aes(x=capacity_mw, y=generation_gwh_2017, col = year(commission_date),
# we can define extra aesthetics that we won't use (see below)
label=name, label2=country)) +
geom_point() +
# plot line indicating full capacity being used (i.e. GWh = MW/1000 * 24*365.25 h)
stat_function(fun = (function(x) 24*365.25/1000*x), col='grey') +
facet_wrap(~primary_fuel, scales='free') +
theme(#legend.position = 'None',
axis.text.x = element_text(angle = 45, hjust = 1)) +
scale_color_viridis_c()
plot
The plotly package can turn the last ggplot into an interactive plot:

plotly::ggplotly(plot, tooltip = c('x', 'y', 'label', 'label2', 'col'))
In this challenge, we’re asked to estimate power generation in 2017 whenever that data is missing, so we can create the test set ourselves.
train <- df %>% filter(!is.na(generation_gwh_2017))
test <- df %>% anti_join(train)
## Joining, by = c("country", "country_long", "name", "id", "capacity_mw", "latitude", "longitude", "primary_fuel", "other_fuel1", "other_fuel2", "other_fuel3", "commission_date", "owner", "year_of_capacity_data", "generation_gwh_2013", "generation_gwh_2014", "generation_gwh_2015", "generation_gwh_2016", "generation_gwh_2017")
At this point, we should check whether the test dataset matches the submission template. (If it doesn't, we did something wrong.)
submission_template <- read_csv('submission_template.csv')
## Parsed with column specification:
## cols(
## id = col_character(),
## prediction = col_double()
## )
template_ids <- submission_template %>% arrange(id) %>% pull(id)
test_ids <- test %>% arrange(id) %>% pull(id)
all(template_ids == test_ids)
## [1] TRUE
(looks good)
You should always check this - if there are big differences:

* you might not be able to train a model at all (e.g. different columns, different factor levels)
* your model will necessarily be biased if the distributions of features are different
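Beyond comparing the ids, one quick way to compare the two sets is to plot a feature's distribution side by side (a sketch; any numeric column works):

```r
# compare the capacity distribution of train vs. test
bind_rows(train = train, test = test, .id = 'set') %>%
  ggplot(aes(x = capacity_mw, fill = set)) +
  geom_density(alpha = 0.5) +
  scale_x_log10()
```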
The package ggmap offers some nice functionality for plotting geographical data. The easiest way is ggmap::qmap('world'), qmap('Germany'), qmap('Tokyo') etc. to download map data, but you will need a personal Google Maps API key to use qmap. See ?qmap for details. Note: Make sure not to download too many tiles per month from Google or you will incur Google Cloud charges!
Instead, we will use open source map material from Stamen here, which has unlimited free calls but requires manual configuration of the plotted area.
left = -180
right = 180
bottom = -60 # min(df$latitude)
top = max(df$latitude)
bbox = c(left=left, bottom=bottom, right=right, top=top)
world <- ggmap::get_stamenmap(bbox, zoom = 2, maptype = 'terrain-background')
## Source : http://tile.stamen.com/terrain-background/2/0/0.png
## Source : http://tile.stamen.com/terrain-background/2/1/0.png
## Source : http://tile.stamen.com/terrain-background/2/2/0.png
## Source : http://tile.stamen.com/terrain-background/2/3/0.png
## Source : http://tile.stamen.com/terrain-background/2/4/0.png
## Not Found (HTTP 404). Failed to aquire tile /terrain-background/2/4/0.png.
## Source : http://tile.stamen.com/terrain-background/2/0/1.png
## Source : http://tile.stamen.com/terrain-background/2/1/1.png
## Source : http://tile.stamen.com/terrain-background/2/2/1.png
## Source : http://tile.stamen.com/terrain-background/2/3/1.png
## Source : http://tile.stamen.com/terrain-background/2/4/1.png
## Not Found (HTTP 404). Failed to aquire tile /terrain-background/2/4/1.png.
## Source : http://tile.stamen.com/terrain-background/2/0/2.png
## Source : http://tile.stamen.com/terrain-background/2/1/2.png
## Source : http://tile.stamen.com/terrain-background/2/2/2.png
## Source : http://tile.stamen.com/terrain-background/2/3/2.png
## Source : http://tile.stamen.com/terrain-background/2/4/2.png
## Not Found (HTTP 404). Failed to aquire tile /terrain-background/2/4/2.png.
# see ??get_map for more options
ggmap(world)
We can now use our downloaded world map and plot things on top of it.
map_data <- df %>% mutate(set = if_else(id %in% train$id, 'train', 'test'))
ggmap(world) +
geom_point(aes(x=longitude, y=latitude, col=set), size=0.5, data=map_data) +
ylim(-60, top) +
theme(legend.position = 'top')
## Warning: Removed 2 rows containing missing values (geom_point).
train %>% dfSummary %>% view
## Switching method to 'browser'
## Output file written: /tmp/Rtmp2Ute1H/file325f2cfbb460.html
test %>% dfSummary %>% view
## Switching method to 'browser'
## Output file written: /tmp/Rtmp2Ute1H/file325f2d88f96f.html
Training on this data is problematic:

* good-quality training data only comes from a handful of countries, so we cannot use country as a feature directly
* it is unclear how other features will translate to other countries
* most test-set observations also don't have data for the previous years
* prediction essentially comes down to capacity, fuel type, and age of the plant
The following table contains total generation per country; according to the data challenge instructions, we are explicitly allowed to use it.
country_data <- read_csv('https://raw.githubusercontent.com/wri/global-power-plant-database/master/resources/generation_by_country_by_fuel_2014.csv')
## Parsed with column specification:
## cols(
## country = col_character(),
## fuel = col_character(),
## generation_gwh_2014 = col_double()
## )
country_data
Using this, we will build additional features:

* plant capacity as a share of the country's total power generation
* plant capacity as a share of the country's total power generation of the same fuel type
Check whether the given country totals match the totals calculated from the per-fuel rows:
cd_total_given <- country_data %>% filter(fuel=='Total') %>%
select(country, total_generation_given = generation_gwh_2014)
cd_total_given
cd_total_calculated <- country_data %>%
filter(fuel!='Total') %>%
group_by(country) %>%
summarize(total_generation_calculated=sum(generation_gwh_2014))
cd_total_calculated
cd_total_calculated %>% inner_join(cd_total_given) %>%
filter(total_generation_calculated != total_generation_given)
## Joining, by = "country"
country_data %>% filter(country == 'Niger')
Let's fix that incorrect total in the country data:
country_data[country_data$country=='Niger' & country_data$fuel=='Total', ]$generation_gwh_2014 <- 690
See ?dplyr::join for more information about join types and basic relational algebra.
df_joined <- df %>%
left_join(
country_data %>% rename(country_gen_by_fuel=generation_gwh_2014),
by=c("country_long" = "country", "primary_fuel" = "fuel")
) %>%
left_join(
country_data %>% filter(fuel == 'Total') %>% transmute(country, country_gen_total = generation_gwh_2014),
by = c("country_long" = "country")
)
df_joined %>% count(is.na(country_gen_by_fuel))
df_joined %>% filter(is.na(country_gen_by_fuel)) %>% count(country_long, sort=T)
# fuel types among the plants missed by the join
df_joined %>% filter(is.na(country_gen_by_fuel)) %>% count(primary_fuel)
# fuel types among all plants, for comparison
df %>% count(primary_fuel)
At this point we could go back and, e.g., change all 'Wave and Tidal' plants to 'Hydro' if we believe that makes sense, or take additional steps to augment our training data. Here we will skip this part.
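For illustration, such a recoding could be sketched like this (hypothetical - we do not apply it in this walkthrough, and it would have to happen before the join above so the country generation figures get picked up):

```r
# hypothetical recode: treat 'Wave and Tidal' plants as 'Hydro'
df <- df %>%
  mutate(primary_fuel = if_else(primary_fuel == 'Wave and Tidal', 'Hydro', primary_fuel))
```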
df_joined <- df_joined %>%
# convert GWh to MW averages
mutate(
country_gen_by_fuel = 1000/24/365.25 * country_gen_by_fuel,
country_gen_total = 1000/24/365.25 * country_gen_total
) %>%
mutate(cap_share_of_country_gen_by_fuel = capacity_mw/country_gen_by_fuel,
cap_share_of_country_gen_total = capacity_mw / country_gen_total)
df_joined %>% summary
## country country_long name
## Length:29910 Length:29910 Length:29910
## Class :character Class :character Class :character
## Mode :character Mode :character Mode :character
##
##
##
##
## id capacity_mw latitude
## Length:29910 Min. : 1.000 Min. :-77.85
## Class :character 1st Qu.: 4.774 1st Qu.: 28.86
## Mode :character Median : 18.900 Median : 40.07
## Mean : 186.295 Mean : 32.50
## 3rd Qu.: 100.000 3rd Qu.: 47.13
## Max. :22500.000 Max. : 71.29
##
## longitude primary_fuel other_fuel1
## Min. :-179.978 Length:29910 Length:29910
## 1st Qu.: -79.211 Class :character Class :character
## Median : -3.747 Mode :character Mode :character
## Mean : -12.459
## 3rd Qu.: 24.931
## Max. : 179.389
##
## other_fuel2 other_fuel3 commission_date
## Length:29910 Length:29910 Min. :1896-01-01
## Class :character Class :character 1st Qu.:1985-12-31
## Mode :character Mode :character Median :2005-01-01
## Mean :1995-06-27
## 3rd Qu.:2012-01-24
## Max. :2018-01-01
## NA's :13607
## owner year_of_capacity_data generation_gwh_2013
## Length:29910 Min. :2000 Min. : -947.60
## Class :character 1st Qu.:2017 1st Qu.: 2.17
## Mode :character Median :2017 Median : 27.03
## Mean :2017 Mean : 689.63
## 3rd Qu.:2017 3rd Qu.: 250.60
## Max. :2018 Max. :50834.00
## NA's :16167 NA's :22914
## generation_gwh_2014 generation_gwh_2015 generation_gwh_2016
## Min. : -989.62 Min. : -864.43 Min. : -768.62
## 1st Qu.: 2.16 1st Qu.: 2.35 1st Qu.: 2.38
## Median : 23.00 Median : 22.09 Median : 17.86
## Mean : 664.84 Mean : 664.27 Mean : 583.62
## 3rd Qu.: 224.96 3rd Qu.: 238.44 3rd Qu.: 187.08
## Max. :32320.92 Max. :59546.86 Max. :32377.48
## NA's :22470 NA's :21766 NA's :20939
## generation_gwh_2017 country_gen_by_fuel country_gen_total
## Min. : -934.94 Min. : 0 Min. : 16.2
## 1st Qu.: 3.20 1st Qu.: 1968 1st Qu.: 38663.6
## Median : 20.46 Median : 4549 Median : 71617.0
## Mean : 579.52 Mean : 38608 Mean :244216.7
## 3rd Qu.: 192.57 3rd Qu.: 32116 3rd Qu.:495004.6
## Max. :35116.00 Max. :469452 Max. :647837.7
## NA's :20697 NA's :684 NA's :616
## cap_share_of_country_gen_by_fuel cap_share_of_country_gen_total
## Min. :0.0000 Min. :0.0000
## 1st Qu.:0.0008 1st Qu.:0.0000
## Median :0.0036 Median :0.0002
## Mean : Inf Mean :0.0101
## 3rd Qu.:0.0152 3rd Qu.:0.0013
## Max. : Inf Max. :6.2482
## NA's :684 NA's :616
Our last operation introduced Inf values, which mlr can't impute directly. Let's fix them ourselves (here we arbitrarily choose 1.0 - is that a good choice?).
df_joined <- df_joined %>% mutate(
cap_share_of_country_gen_by_fuel = if_else(is.infinite(cap_share_of_country_gen_by_fuel), 1.0, cap_share_of_country_gen_by_fuel)
)
Get rid of all columns that won’t be used in the model (except ID)
names(df_joined)
## [1] "country" "country_long"
## [3] "name" "id"
## [5] "capacity_mw" "latitude"
## [7] "longitude" "primary_fuel"
## [9] "other_fuel1" "other_fuel2"
## [11] "other_fuel3" "commission_date"
## [13] "owner" "year_of_capacity_data"
## [15] "generation_gwh_2013" "generation_gwh_2014"
## [17] "generation_gwh_2015" "generation_gwh_2016"
## [19] "generation_gwh_2017" "country_gen_by_fuel"
## [21] "country_gen_total" "cap_share_of_country_gen_by_fuel"
## [23] "cap_share_of_country_gen_total"
df_joined <- df_joined %>% select(-country, -country_long, -name, -latitude, -longitude,
-other_fuel2, -other_fuel3, -owner,
-generation_gwh_2013, -generation_gwh_2014, -generation_gwh_2015, -generation_gwh_2016,
-country_gen_by_fuel, -country_gen_total)
train <- df_joined %>% filter(!is.na(generation_gwh_2017))
test <- df_joined %>% anti_join(train) %>% select(-generation_gwh_2017)
## Joining, by = c("id", "capacity_mw", "primary_fuel", "other_fuel1", "commission_date", "year_of_capacity_data", "generation_gwh_2017", "cap_share_of_country_gen_by_fuel", "cap_share_of_country_gen_total")
dfSummary(train) %>% summarytools::view()
## Switching method to 'browser'
## Output file written: /tmp/Rtmp2Ute1H/file325f15ff93c4.html
dfSummary(test) %>% summarytools::view()
## Switching method to 'browser'
## Output file written: /tmp/Rtmp2Ute1H/file325f51d2ee30.html
library(mlr)
Let's write a function that we will apply to both the training and the test set.
While we could do this manually for each column, mlr provides shortcuts that can save us some typing. There are no clear boundaries between what should be done here in the modeling section and what belongs in previous steps - it's up to you.
Note: These are not supposed to be the best imputation methods for this dataset. Other ways may be better depending on the context of each column.
prepare_for_mlr <- function(df){
df_prepared <- df %>%
#select(-id) %>%
mutate_if(is.Date, decimal_date) %>%
mlr::impute(
classes = list(
# for each data type, specify a "standard" imputation method to apply
## As an example, we'll set NAs in character columns to 'none' and numeric columns to their mean
character = imputeConstant('none'),
integer = imputeMean(),
numeric = imputeMean()
),
cols = list(
# for columns that should NOT use the standard method for their type, you can overwrite it
# example: let's impute the missing capacity-share columns with 0 instead of the mean:
cap_share_of_country_gen_by_fuel = imputeConstant(0),
cap_share_of_country_gen_total = imputeConstant(0)
)
) %>% .$data %>% as_tibble() %>% #return a data frame instead of mlr's imputation object
mutate_if(is.character, as_factor)
}
train_ids <- train$id
# keep a copy of the ids; they are dropped later when creating the mlr task
train <- prepare_for_mlr(train)
test <- prepare_for_mlr(test)
See mlr tutorials for details on tasks, learners, resampling methods, etc.
task <- makeRegrTask(
id = 'predict estimated power generation',
data = train %>% select(-id),
target = 'generation_gwh_2017'
)
## Warning in makeTask(type = type, data = data, weights = weights, blocking
## = blocking, : Provided data is not a pure data.frame but from class tbl_df,
## hence it will be converted.
List of possible learners: https://mlr.mlr-org.com/articles/tutorial/integrated_learners.html#regression-59
## the chosen learner requires its backing package (here: rpart) to be installed (but not loaded)
learner <- makeLearner(
id = 'my name for my learner',
cl = 'regr.rpart', #rpart regression tree, replace with model of your choice
predict.type = 'response',
fix.factors.prediction = TRUE, # deals with differing factor levels in train and test
par.vals = list() # set hyperparameters for the learner class given. See learner list above for possible values
)
Note:
Some learners don’t work with factors directly but require you to explicitly one-hot-encode first. You might get an error like
> Error in checkLearnerBeforeTrain(task, learner, weights) :
> Task 'predict future outages' has factor inputs in [columns], but learner [learner] does not support that!
If this happens to you, mlr provides a wrapper around the learner that can take care of OHE for you. See this line: learner <- makeDummyFeaturesWrapper(learner = learner, method = "1-of-n")
resampling_strategy <- makeResampleDesc("CV", iters = 5)
resample_result <- mlr::resample(
learner = learner,
task = task,
resampling = resampling_strategy,
measures = list(rmse, mae), # specify a list of mlr performance 'Measures' that you're interested in.
# list of measures available at https://mlr.mlr-org.com/articles/tutorial/measures.html
keep.pred = TRUE
)
## Resampling: cross-validation
## Measures: rmse mae
## [Resample] iter 1: 850.0904414 340.8927305
## [Resample] iter 2: 740.6603660 304.5180408
## [Resample] iter 3: 958.5884331 358.1965185
## [Resample] iter 4: 1140.4016788 406.6053008
## [Resample] iter 5: 838.1771198 339.6814100
##
## Aggregated Result: rmse.test.rmse=915.7707348,mae.test.mean=349.9788001
##
resample_result$runtime
## [1] 0.1897836
After resampling, you might want to iterate and tune your learner (see the advanced section on the mlr website). When you're done, train your model:
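A tuning iteration could be sketched like this (illustrative only - the rpart hyperparameters and ranges are assumptions, not recommendations):

```r
# random search over two rpart hyperparameters, evaluated with the same CV strategy
tune_params <- makeParamSet(
  makeIntegerParam('minsplit', lower = 5, upper = 50),
  makeNumericParam('cp', lower = 0.001, upper = 0.1)
)
tune_result <- tuneParams(
  learner = learner,
  task = task,
  resampling = resampling_strategy,
  measures = list(rmse),
  par.set = tune_params,
  control = makeTuneControlRandom(maxit = 20)
)
# apply the best hyperparameters found before the final training run
learner <- setHyperPars(learner, par.vals = tune_result$x)
```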
model <- mlr::train(learner, task)
model
## Model for learner.id=my name for my learner; learner.class=regr.rpart
## Trained on: task.id = predict estimated power generation; obs = 9213; features = 7
## Hyperparameters: xval=0
Partial dependence (PD) plots allow you to look at the influence of specific features on the target variable.
See the 'advanced' tab on the mlr website for more ways to analyze your model.
pd <- generatePartialDependenceData(model, task) #,'primary_fuel')
## Loading required package: mmpf
## Warning in uniformGrid.factor(data[[vars]], n[1]): length.out is less than
## the number of levels
## Warning in uniformGrid.factor(data[[vars]], n[1]): length.out is less than
## the number of levels
plotPartialDependence(pd)
You can access the actual trained model which is wrapped by mlr using model$learner.model
model$learner.model
## n= 9213
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 9213 39208660000 579.5247
## 2) capacity_mw< 1072.65 8823 5100713000 259.0643
## 4) capacity_mw< 497.15 8273 665236100 122.9528 *
## 5) capacity_mw>=497.15 550 1976773000 2306.4290 *
## 3) capacity_mw>=1072.65 390 12703610000 7829.3260
## 6) capacity_mw< 1987.85 285 3084789000 5576.3330
## 12) primary_fuel=Oil,Gas,Hydro,Geothermal 109 758003800 3490.6710 *
## 13) primary_fuel=Coal,Nuclear,Petcoke 176 1558988000 6868.0220
## 26) primary_fuel=Coal,Petcoke 153 909545000 6266.9440 *
## 27) primary_fuel=Nuclear 23 226445000 10866.5000 *
## 7) capacity_mw>=1987.85 105 4245543000 13944.5900
## 14) cap_share_of_country_gen_by_fuel< 0.02228192 58 968686700 10547.3900 *
## 15) cap_share_of_country_gen_by_fuel>=0.02228192 47 1781442000 18136.8800
## 30) cap_share_of_country_gen_by_fuel>=0.05584242 11 358096200 11787.0000 *
## 31) cap_share_of_country_gen_by_fuel< 0.05584242 36 844292100 20077.1200 *
plot(model$learner.model, compress = TRUE)
text(model$learner.model, use.n = TRUE)
predict(model, task)$data %>%
ggplot(aes(x=truth, y=response)) + geom_point() +stat_function(fun=identity)
predictions <- predict(model, newdata=test)$data
## Warning in predict.WrappedModel(model, newdata = test): Provided data for
## prediction is not a pure data.frame but from class tbl_df, hence it will be
## converted.
test_predicted <- bind_cols(test, predictions)
submission <- test_predicted %>%
transmute(id, prediction=response) %>%
mutate(id = as.character(id)) %>%
arrange(id)
submission
Note: in the actual Analytics Cup, the id column might be either character or integer!
write_csv(submission, 'predictions_MyGroupName_SubmissionNumber.csv')
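Before uploading, it can't hurt to re-read the written file and check it against the template (a sketch; stopifnot() aborts with an error if any check fails):

```r
# re-read the submission, forcing id to character so the comparison is type-safe
submission_check <- read_csv('predictions_MyGroupName_SubmissionNumber.csv',
                             col_types = cols(id = col_character()))
stopifnot(
  nrow(submission_check) == nrow(submission_template),
  setequal(submission_check$id, as.character(submission_template$id)),
  !any(is.na(submission_check$prediction))
)
```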